{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distribution Strategy Design Pattern\n",
    "\n",
    "This notebook demonstrates how to use distributed training with Keras."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow import feature_column as fc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column schema for the CSV files. The names are listed in the order\n",
    "# the columns are expected to appear in each file.\n",
    "CSV_COLUMNS = [\n",
    "    \"weight_pounds\",\n",
    "    \"is_male\",\n",
    "    \"mother_age\",\n",
    "    \"plurality\",\n",
    "    \"gestation_weeks\",\n",
    "    \"mother_race\",\n",
    "]\n",
    "\n",
    "# Name of the column used as the label.\n",
    "LABEL_COLUMN = \"weight_pounds\"\n",
    "\n",
    "# Default value for each CSV column, one single-element list per column,\n",
    "# in CSV_COLUMNS order. is_male, plurality and mother_race get the string\n",
    "# default \"null\"; the remaining columns get the float default 0.0.\n",
    "DEFAULTS = [\n",
    "    [0.0],       # weight_pounds\n",
    "    [\"null\"],    # is_male\n",
    "    [0.0],       # mother_age\n",
    "    [\"null\"],    # plurality\n",
    "    [0.0],       # gestation_weeks\n",
    "    [\"null\"],    # mother_race\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def features_and_labels(row_data):\n",
    "    \"\"\"Separates the label from the remaining feature columns.\n",
    "\n",
    "    The label entry is removed from `row_data` itself, so the returned\n",
    "    feature dictionary is the same object that was passed in.\n",
    "\n",
    "    Args:\n",
    "        row_data: Dictionary of CSV column names and tensor values.\n",
    "    Returns:\n",
    "        Tuple of (dictionary of feature tensors, label tensor).\n",
    "    \"\"\"\n",
    "    target = row_data.pop(LABEL_COLUMN)\n",
    "    features = row_data\n",
    "\n",
    "    return features, target\n",
    "\n",
    "\n",
    "def load_dataset(pattern, batch_size=1, mode=tf.estimator.ModeKeys.EVAL):\n",
    "    \"\"\"Loads dataset using the tf.data API from CSV files.\n",
    "\n",
    "    Args:\n",
    "        pattern: str, file pattern to glob into list of files.\n",
    "        batch_size: int, the number of examples per batch.\n",
    "        mode: tf.estimator.ModeKeys to determine if training or evaluating.\n",
    "    Returns:\n",
    "        `Dataset` object yielding (features, label) batches. In TRAIN mode\n",
    "        the batches are additionally shuffled and repeated indefinitely;\n",
    "        in any other mode the dataset makes a single pass over the files.\n",
    "    \"\"\"\n",
    "    # Make a CSV dataset. make_csv_dataset cycles through the files forever\n",
    "    # when num_epochs is left at its default (None), which would make the\n",
    "    # evaluation dataset infinite too. Ask for a single pass here and let\n",
    "    # the TRAIN branch below decide whether to repeat.\n",
    "    dataset = tf.data.experimental.make_csv_dataset(\n",
    "        file_pattern=pattern,\n",
    "        batch_size=batch_size,\n",
    "        column_names=CSV_COLUMNS,\n",
    "        column_defaults=DEFAULTS,\n",
    "        num_epochs=1)\n",
    "\n",
    "    # Map dataset to features and label\n",
    "    dataset = dataset.map(map_func=features_and_labels)  # features, label\n",
    "\n",
    "    # Shuffle and repeat for training\n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        dataset = dataset.shuffle(buffer_size=1000).repeat()\n",
    "\n",
    "    # Prefetch batches while the model is busy with the current one.\n",
    "    # AUTOTUNE lets tf.data pick the buffer size at runtime (a literal 1\n",
    "    # would mean a fixed buffer of one batch, not autotuning).\n",
    "    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build model as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_input_layers():\n",
    "    \"\"\"Builds one Keras input layer per model feature.\n",
    "\n",
    "    Returns:\n",
    "        Dictionary mapping each feature name to its\n",
    "        `tf.keras.layers.Input` layer.\n",
    "    \"\"\"\n",
    "    # Features grouped by the dtype of the scalar input they are fed as\n",
    "    columns_by_dtype = (\n",
    "        (\"float32\", [\"mother_age\", \"gestation_weeks\"]),\n",
    "        (\"string\", [\"is_male\", \"plurality\", \"mother_race\"]),\n",
    "    )\n",
    "\n",
    "    inputs = {}\n",
    "    for dtype, colnames in columns_by_dtype:\n",
    "        for colname in colnames:\n",
    "            # shape=() means one scalar value per example\n",
    "            inputs[colname] = tf.keras.layers.Input(\n",
    "                name=colname, shape=(), dtype=dtype)\n",
    "\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And set up feature columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def categorical_fc(name, values):\n",
    "    \"\"\"Creates an indicator column over a fixed vocabulary.\n",
    "\n",
    "    Args:\n",
    "        name: Key of the feature the column reads.\n",
    "        values: Vocabulary list of the categorical column.\n",
    "    Returns:\n",
    "        Indicator column wrapping a vocabulary-list categorical column.\n",
    "    \"\"\"\n",
    "    return fc.indicator_column(\n",
    "        categorical_column=fc.categorical_column_with_vocabulary_list(\n",
    "            key=name, vocabulary_list=values))\n",
    "\n",
    "\n",
    "def create_feature_columns():\n",
    "    \"\"\"Creates the feature columns used by the model.\n",
    "\n",
    "    Returns:\n",
    "        Dictionary mapping a feature name to its feature column.\n",
    "    \"\"\"\n",
    "    feature_columns = {}\n",
    "\n",
    "    # Numeric features\n",
    "    for colname in [\"mother_age\", \"gestation_weeks\"]:\n",
    "        feature_columns[colname] = fc.numeric_column(key=colname)\n",
    "\n",
    "    # String features with a known vocabulary become indicator columns\n",
    "    vocabularies = {\n",
    "        \"is_male\": [\"True\", \"False\", \"Unknown\"],\n",
    "        \"plurality\": [\"Single(1)\", \"Twins(2)\", \"Triplets(3)\",\n",
    "                      \"Quadruplets(4)\", \"Quintuplets(5)\", \"Multiple(2+)\"],\n",
    "    }\n",
    "    for colname, vocabulary in vocabularies.items():\n",
    "        feature_columns[colname] = categorical_fc(colname, vocabulary)\n",
    "\n",
    "    # mother_race is hashed into 17 buckets, then wrapped as an indicator\n",
    "    race_hashed = fc.categorical_column_with_hash_bucket(\n",
    "        \"mother_race\", hash_bucket_size=17, dtype=tf.dtypes.string)\n",
    "    feature_columns[\"mother_race\"] = fc.indicator_column(race_hashed)\n",
    "\n",
    "    # Cross of is_male and plurality, embedded into 2 dimensions\n",
    "    gender_plurality_cross = fc.crossed_column(\n",
    "        [\"is_male\", \"plurality\"], hash_bucket_size=18)\n",
    "    feature_columns[\"gender_x_plurality\"] = fc.embedding_column(\n",
    "        gender_plurality_cross, dimension=2)\n",
    "\n",
    "    return feature_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_outputs(inputs):\n",
    "    \"\"\"Stacks the hidden layers and the output layer on top of `inputs`.\n",
    "\n",
    "    Args:\n",
    "        inputs: Tensor fed into the first hidden layer.\n",
    "    Returns:\n",
    "        Output tensor of the final single-unit layer.\n",
    "    \"\"\"\n",
    "    # Two ReLU hidden layers of [64, 32] units, just like the BQML DNN\n",
    "    hidden = inputs\n",
    "    for layer_name, units in ((\"h1\", 64), (\"h2\", 32)):\n",
    "        hidden = layers.Dense(\n",
    "            units, activation=\"relu\", name=layer_name)(hidden)\n",
    "\n",
    "    # Regression problem, so the final layer uses a linear activation\n",
    "    return layers.Dense(units=1, activation=\"linear\", name=\"weight\")(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse(y_true, y_pred):\n",
    "    \"\"\"Computes the root mean squared error.\n",
    "\n",
    "    Args:\n",
    "        y_true: Tensor of label values.\n",
    "        y_pred: Tensor of predicted values.\n",
    "    Returns:\n",
    "        Scalar tensor: square root of the mean of the squared differences.\n",
    "    \"\"\"\n",
    "    squared_error = (y_pred - y_true) ** 2\n",
    "    mean_squared_error = tf.reduce_mean(squared_error)\n",
    "\n",
    "    return tf.sqrt(mean_squared_error)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the model and set up distribution strategy\n",
    "\n",
    "Next, we'll combine the components of the model above to build the DNN model. \n",
    "\n",
    "This is also where we'll define the distribution strategy. To do that, we'll build the model inside the scope of the distribution strategy. Notice the output after executing the cell below. We'll see\n",
    "```\n",
    "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n",
    "```\n",
    "This indicates that we're using the MirroredStrategy on 4 GPUs. That is because my machine has 4 GPUs. Your output may look different depending on how many GPUs you have on your device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "Here is our DNN architecture so far:\n",
      "\n",
      "Model: \"model_1\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "gestation_weeks (InputLayer)    [(None,)]            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "is_male (InputLayer)            [(None,)]            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "mother_age (InputLayer)         [(None,)]            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "mother_race (InputLayer)        [(None,)]            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "plurality (InputLayer)          [(None,)]            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "dense_features_1 (DenseFeatures (None, 30)           36          gestation_weeks[0][0]            \n",
      "                                                                 is_male[0][0]                    \n",
      "                                                                 mother_age[0][0]                 \n",
      "                                                                 mother_race[0][0]                \n",
      "                                                                 plurality[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "h1 (Dense)                      (None, 64)           1984        dense_features_1[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "h2 (Dense)                      (None, 32)           2080        h1[0][0]                         \n",
      "__________________________________________________________________________________________________\n",
      "weight (Dense)                  (None, 1)            33          h2[0][0]                         \n",
      "==================================================================================================\n",
      "Total params: 4,133\n",
      "Trainable params: 4,133\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "def build_dnn_model():\n",
    "    \"\"\"Assembles and compiles a simple DNN with the Keras Functional API.\n",
    "\n",
    "    Returns:\n",
    "        Compiled `tf.keras.models.Model` object.\n",
    "    \"\"\"\n",
    "    # One Keras input layer per feature, keyed by feature name\n",
    "    input_layers = create_input_layers()\n",
    "\n",
    "    # Feature columns describing how each input is transformed\n",
    "    columns = create_feature_columns().values()\n",
    "\n",
    "    # DenseFeatures is constructed from the feature columns and then\n",
    "    # called on the input layers (Functional API: LayerConstructor()(inputs))\n",
    "    dense_features = layers.DenseFeatures(feature_columns=columns)\n",
    "    dnn_inputs = dense_features(input_layers)\n",
    "\n",
    "    # Hidden layers plus the output layer\n",
    "    output = get_model_outputs(dnn_inputs)\n",
    "\n",
    "    # Tie inputs and output together, then compile with MSE loss and\n",
    "    # report both RMSE and MSE as metrics\n",
    "    model = tf.keras.models.Model(inputs=input_layers, outputs=output)\n",
    "    model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[rmse, \"mse\"])\n",
    "\n",
    "    return model\n",
    "\n",
    "# Create the distribution strategy and build the model inside its scope\n",
    "# so the model is created under that strategy.\n",
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "with mirrored_strategy.scope():\n",
    "    model = build_dnn_model()\n",
    "\n",
    "print(\"Here is our DNN architecture so far:\\n\")\n",
    "# summary() prints the table itself and returns None, so it is not\n",
    "# wrapped in print() (that would emit a trailing \"None\").\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how many GPU devices you have attached to your machine, run the cell below. As mentioned above, I have 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of devices: 4\n"
     ]
    }
   ],
   "source": [
    "# Report how many replicas the strategy keeps in sync\n",
    "print('Number of devices: {}'.format(\n",
    "    mirrored_strategy.num_replicas_in_sync))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
